from ACState.object_dict import ObjDict
import numpy as np




# Compute-type names, grouped by family.  Every "all_*" list mirrors its
# counterpart with an "ALL_" prefix on each name.
passive_forward = [
    "PASSIVE",
    "PASSIVE_VAR",
    "PASSIVE_RAW",
    "PASSIVE_LIKELIHOOD",
]

all_passive_forward = [
    "ALL_PASSIVE",
    "ALL_PASSIVE_VAR",
    "ALL_PASSIVE_RAW",
    "ALL_PASSIVE_LIKELIHOOD",
]

open_forward = [
    "ACTIVE_OPEN",
    "ACTIVE_OPEN_VAR",
    "ACTIVE_OPEN_RAW",
    "ACTIVE_OPEN_LIKELIHOOD",
    "ACTIVE_OPEN_GIVEN_LIKELIHOOD",
    "ACTIVE_OPEN_GRADIENT",
]

all_open_forward = [
    "ALL_ACTIVE_OPEN",
    "ALL_ACTIVE_OPEN_VAR",
    "ALL_ACTIVE_OPEN_RAW",
    "ALL_ACTIVE_OPEN_LIKELIHOOD",
    "ALL_ACTIVE_OPEN_GIVEN_LIKELIHOOD",
    "ALL_ACTIVE_OPEN_GRADIENT",
]

mask_forward = [
    "ACTIVE",
    "ACTIVE_VAR",
    "ACTIVE_RAW",
    "ACTIVE_LIKELIHOOD",
    "ACTIVE_GIVEN_LIKELIHOOD",
    "ACTIVE_GRADIENT",
]

all_mask_forward = [
    "ALL_ACTIVE",
    "ALL_ACTIVE_VAR",
    "ALL_ACTIVE_RAW",
    "ALL_ACTIVE_LIKELIHOOD",
    "ALL_ACTIVE_GIVEN_LIKELIHOOD",
    "ALL_ACTIVE_GRADIENT",
]

inter = [
    "INTERACTION",
    "INTERACTION_RAW",
    "INTERACTION_HOT",
    "INTERACTION_BINARIES",
]

all_inter = [
    "ALL_INTERACTION",
    "ALL_INTERACTION_RAW",
    "ALL_INTERACTION_HOT",
    "ALL_INTERACTION_BINARIES",
]

proximity_compute = [
    "PROXIMITY",
    "PROXIMITY_FLAT",
    "PROXIMITY_FULL",
    "PROXIMITY_ALL",
]

batch_compute = [
    "TRACE",  # just sends back the trace values
    "DONE",  # just sends back the done values
]

# Flat list of every compute name.  The order is significant: a name's
# position in this list is the integer id assigned to it in compute_types.
compute_names = [
    *passive_forward,
    *all_passive_forward,
    *open_forward,
    *all_open_forward,
    *mask_forward,
    *all_mask_forward,
    *inter,
    *all_inter,
    *proximity_compute,
    *batch_compute,
]


def check_any(check_vals, check_in):
    """Return True if at least one element of ``check_vals`` is in ``check_in``.

    Membership is tested with the ``in`` operator, so ``check_in`` may be any
    container that supports it.  Returns False when ``check_vals`` is empty.
    """
    return any(candidate in check_in for candidate in check_vals)

def get_base_infer(inference_values):
    """Build a dict mapping compute names to a string label per name family.

    Forward-type names map to a fixed label per family ("single_passive",
    "passive", "full", "mask", "all_full", "all_mask").  The label for the
    interaction families depends on ``inference_values``:

    * names in ``inter`` map to "mask" if any ``mask_forward`` name is in
      ``inference_values``, otherwise to "probs";
    * names in ``all_inter`` map to "all_mask" if any ``all_mask_forward``
      name is in ``inference_values``, otherwise to "all_probs".

    Proximity and batch names are not included in the result.
    """
    if check_any(mask_forward, inference_values):
        inter_label = "mask"
    else:
        inter_label = "probs"

    if check_any(all_mask_forward, inference_values):
        all_inter_label = "all_mask"
    else:
        all_inter_label = "all_probs"

    # (family, label) pairs, applied in this order so the resulting dict
    # keeps the same key ordering as a left-to-right merge of the families.
    family_labels = (
        (passive_forward, "single_passive"),
        (all_passive_forward, "passive"),
        (open_forward, "full"),
        (mask_forward, "mask"),
        (all_open_forward, "all_full"),
        (all_mask_forward, "all_mask"),
        (inter, inter_label),
        (all_inter, all_inter_label),
    )

    base_infer = {}
    for family, label in family_labels:
        base_infer.update(dict.fromkeys(family, label))
    return base_infer

def get_additional_infer(inference_values):
    """Return the list of additional inference kinds that are required.

    The result is ``["gradient"]`` when at least one element of
    ``inference_values`` is contained in the module-level ``grad_names``
    (the ids of ACTIVE_OPEN_GRADIENT, ALL_ACTIVE_OPEN_GRADIENT,
    ACTIVE_GRADIENT and ALL_ACTIVE_GRADIENT), and an empty list otherwise.

    Args:
        inference_values: iterable of requested compute types.

    Returns:
        A new list, either ``["gradient"]`` or ``[]``.
    """
    # Kinds listed by the original author; only "gradient" is produced:
    #   attention
    #   gradient  # TODO: since this requires an optimization step, probably not immediately usable
    #   hard
    #   soft
    #   mixed
    #   flat
    # NOTE(review): grad_names holds integer ids (compute_types values), while
    # get_base_infer tests string names against its argument -- confirm that
    # callers pass ids here.
    needs_gradient = any(value in grad_names for value in inference_values)
    return ["gradient"] if needs_gradient else []


# Forward map: compute name -> integer id (the name's index in compute_names).
compute_types = ObjDict({name: idx for idx, name in enumerate(compute_names)})
# Reverse map: integer id -> compute name.
num_compute_types = ObjDict({idx: name for idx, name in enumerate(compute_names)})

# Groupings of compute-type ids (values of compute_types) by output kind.

# Ids of the un-suffixed forward names.
mean_names = [compute_types[name] for name in (
    "PASSIVE",
    "ALL_PASSIVE",
    "ACTIVE_OPEN",
    "ALL_ACTIVE_OPEN",
    "ACTIVE",
    "ALL_ACTIVE",
)]

# Ids of the "_VAR" forward names.
var_names = [compute_types[name] for name in (
    "PASSIVE_VAR",
    "ALL_PASSIVE_VAR",
    "ACTIVE_OPEN_VAR",
    "ALL_ACTIVE_OPEN_VAR",
    "ACTIVE_VAR",
    "ALL_ACTIVE_VAR",
)]

# Ids of the "_RAW" forward names.
raw_names = [compute_types[name] for name in (
    "PASSIVE_RAW",
    "ALL_PASSIVE_RAW",
    "ACTIVE_OPEN_RAW",
    "ALL_ACTIVE_OPEN_RAW",
    "ACTIVE_RAW",
    "ALL_ACTIVE_RAW",
)]

# Ids of the "_LIKELIHOOD" forward names (excluding "_GIVEN_LIKELIHOOD").
like_names = [compute_types[name] for name in (
    "PASSIVE_LIKELIHOOD",
    "ALL_PASSIVE_LIKELIHOOD",
    "ACTIVE_OPEN_LIKELIHOOD",
    "ALL_ACTIVE_OPEN_LIKELIHOOD",
    "ACTIVE_LIKELIHOOD",
    "ALL_ACTIVE_LIKELIHOOD",
)]

# Ids of the "_GIVEN_LIKELIHOOD" names.
given_names = [compute_types[name] for name in (
    "ACTIVE_OPEN_GIVEN_LIKELIHOOD",
    "ALL_ACTIVE_OPEN_GIVEN_LIKELIHOOD",
    "ACTIVE_GIVEN_LIKELIHOOD",
    "ALL_ACTIVE_GIVEN_LIKELIHOOD",
)]

# Ids of the "_GRADIENT" names.
grad_names = [compute_types[name] for name in (
    "ACTIVE_OPEN_GRADIENT",
    "ALL_ACTIVE_OPEN_GRADIENT",
    "ACTIVE_GRADIENT",
    "ALL_ACTIVE_GRADIENT",
)]

# Ids of the proximity names.
prox_names = [compute_types[name] for name in (
    "PROXIMITY",
    "PROXIMITY_FLAT",
    "PROXIMITY_FULL",
    "PROXIMITY_ALL",
)]
